"""Plot experiment results."""

import json
import os

from absl import app
from absl import flags
import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils.extmath import randomized_svd

# The 'seaborn' style was renamed 'seaborn-v0_8' in matplotlib 3.6 and the old
# name was removed in 3.8; use the new name when this matplotlib provides it so
# the script keeps working on both old and new releases.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available
              else 'seaborn')

# Command-line flags: where the per-dataset JSON results live, where to write
# the figure (shown interactively when unset), and which metric to plot.
flags.DEFINE_string('input_dir', 'experiments', 'Dir with measurements')
flags.DEFINE_string('output_file', None, 'File with plots')
flags.DEFINE_enum('metric', 'loss', ['loss', 'time'], 'Metric to plot')
FLAGS = flags.FLAGS

# Datasets to plot; main() reads `<input_dir>/<name>.json` for each entry.
DS = ['mnist', 'fashion_mnist', 'smallnorb', 'colorectal_histology']

def plot(ax, results, title='', *, metric=None,
         algos=('svd_w', 'adam', 'em', 'greedy', 'sample', 'svd'),
         highlight='svd_w'):
  """Draw one metric-vs-rank curve per algorithm on ``ax``.

  Args:
    ax: Matplotlib Axes to draw on.
    results: Mapping ``algo -> {metric_name -> sequence of values}``; the i-th
      value of a sequence is plotted at rank ``i + 1``.
    title: Subplot title.
    metric: Metric key to plot (e.g. 'loss' or 'time'). Defaults to
      ``FLAGS.metric`` so existing callers behave exactly as before.
    algos: Algorithms to plot, in drawing (and legend) order.
    highlight: Algorithm drawn fully opaque; every other curve uses alpha 0.5
      so the highlighted one stands out.

  Raises:
    KeyError: if an algorithm in ``algos``, or ``metric`` for one of them, is
      missing from ``results``.
  """
  if metric is None:
    # Resolve lazily so the function is usable without parsed absl flags.
    metric = FLAGS.metric

  for algo in algos:
    values = results[algo][metric]
    ranks = range(1, len(values) + 1)  # Ranks are 1-based.
    alpha = 1 if algo == highlight else 0.5
    ax.plot(ranks, values, linewidth=3, alpha=alpha, label=algo)

  ylabel = metric.capitalize()
  if metric == 'time':
    ylabel += ' (seconds)'
  ax.set_ylabel(ylabel, fontsize=18)
  ax.set_xlabel('Rank', fontsize=18)
  ax.legend(fontsize=16)
  ax.set_title(title, fontsize=20)


def main(argv) -> None:
  """Load per-dataset results and draw one subplot per dataset in ``DS``.

  For every name in ``DS`` this reads ``<input_dir>/<name>.json`` and plots
  it with ``plot`` on a two-column grid titled by the dataset name. The figure
  is written to ``--output_file`` when that flag is set, otherwise shown
  interactively.

  Args:
    argv: Command-line arguments forwarded by ``absl.app.run``; unused.

  Raises:
    ValueError: if a dataset's JSON results file does not exist.
  """
  del argv  # Unused.
  plt.rc('xtick', labelsize=14)    # fontsize of the tick labels
  plt.rc('ytick', labelsize=14)    # fontsize of the tick labels

  results = {}
  for ds in DS:
    input_file = os.path.join(FLAGS.input_dir, f'{ds}.json')
    try:
      with open(input_file, 'r', encoding='utf-8') as fp:
        results[ds] = json.load(fp)
    except FileNotFoundError as err:
      raise ValueError(f'Path {input_file} does not exist') from err

  # Size the grid from DS instead of hard-coding one plot() call per cell;
  # for the current four datasets this is the same 2x2, 16x10-inch figure.
  ncols = 2
  nrows = (len(DS) + ncols - 1) // ncols
  fig = plt.figure(figsize=(16, 5 * nrows))
  axes = fig.subplots(nrows=nrows, ncols=ncols, squeeze=False).ravel()
  for ax, ds in zip(axes, DS):
    plot(ax, results[ds], ds)
  # Hide leftover cells when len(DS) is not a multiple of ncols.
  for ax in axes[len(DS):]:
    ax.set_visible(False)

  fig.tight_layout(h_pad=3, w_pad=3)

  if FLAGS.output_file is not None:
    fig.savefig(FLAGS.output_file)
  else:
    plt.show()


if __name__ == '__main__':
  app.run(main)
